from typing import Tuple, Optional, Any, List, Dict, Union
from torch.optim import Optimizer
from torch.nn import Parameter
import torch

class ZeroRedundancyOptimizer(Optimizer):
    """Interface stub for an optimizer wrapper (no implementation here).

    From the naming and the rank-oriented API below, this presumably wraps a
    regular ``Optimizer`` and shards its state across distributed ranks
    (ZeRO-style) — TODO confirm against the implementing module.
    """

    # `optim` is the wrapped optimizer type; `group` is presumably a process
    # group handle (verify against caller); `bucket_cap_kb` suggests a
    # communication bucket size in KiB — confirm. Extra keyword defaults are
    # forwarded via **default.
    def __init__(
        self, params, optim: Optimizer = ..., group: Optional[Any] = ..., bucket_cap_kb: int = ..., **default: Any
    ) -> None: ...
    # Returns a list of lists of param-group dicts; presumably one inner list
    # per rank — TODO confirm.
    def partition_parameters(self) -> List[List[dict]]: ...
    # Maps each device to nested lists of parameters residing on it.
    def per_device_params(self) -> Dict[torch.device, List[List[Parameter]]]: ...
    # Maps each parameter tensor to an integer rank (presumably the rank that
    # owns its optimizer state — verify).
    def param_to_rank(self) -> Dict[torch.Tensor, int]: ...
    # State dict covering only this rank's shard, per the "local" naming.
    def local_state_dict(self) -> Dict[Any, Any]: ...
    # Gathers state onto `recipient_rank` (default 0); returns nothing, so the
    # consolidated state is presumably retrieved via a separate call.
    def consolidate_state_dict(self, recipient_rank: int = 0) -> None: ...
    # Inverse of local_state_dict: restores this rank's shard.
    def load_local_state_dict(self, state_dict: Dict[Any, Any]) -> None: ...
    # Signature mirrors torch.nn.utils.clip_grad_norm_ (max_norm, norm_type=2.0)
    # and returns the total norm as a tensor.
    def clip_grad_norm(self, max_norm: Union[float, int], norm_type: Union[float, int] = 2.0) -> torch.Tensor: ...
